import sys, time
import numpy as np
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileProgram, compileShader

# ---------------------
# Compute Shader
# ---------------------
COMPUTE_SRC = """
#version 430
layout(local_size_x = 16, local_size_y = 16) in;

layout(rgba32f, binding = 0) uniform image2D latticeImage;

// Per-workgroup lookup tables, filled cooperatively at the start of main().
shared float fibLUT[128];   // fibLUT[k]   = phi^(k % 16)
shared float primeLUT[128]; // primeLUT[k] = 2 + (3k mod 127)

uniform float phi;
uniform float phiInv;   // set by the host but not referenced below
uniform float cycle;
uniform float omegaTime;
uniform int numInstances;

void main() {
    ivec2 size = imageSize(latticeImage);
    ivec2 gid = ivec2(gl_GlobalInvocationID.xy);
    int idx = gid.y * size.x + gid.x; // row-major 2D -> instance ID

    // Fill the shared LUTs. barrier() must be reached by every invocation in
    // the workgroup, so no invocation may leave main() before this point.
    uint k = gl_LocalInvocationIndex;
    if (k < 128u) {
        fibLUT[k] = pow(phi, float(k % 16u));
        primeLUT[k] = float(2u + k * 3u % 127u);
    }
    barrier();

    // Discard invocations outside the image (the dispatch rounds up to whole
    // 16x16 workgroups) or beyond the requested instance count.
    if (any(greaterThanEqual(gid, size)) || idx >= numInstances) return;

    float r = length(vec2(gid) / vec2(size) - 0.5) * 2.0;

    // Prismatic recursion superposition
    float phi_harm = pow(phi, float(idx % 16));
    float fib_harm = fibLUT[idx % 128];
    float dyadic = float(1 << (idx % 16));
    float prime_harm = primeLUT[idx % 128];
    float Omega = 0.5 + 0.5 * sin(omegaTime + float(idx) * 0.01);
    float r_dim = pow(r, float((idx % 7) + 1));

    float val = sqrt(phi_harm * fib_harm * dyadic * prime_harm * Omega) * r_dim;

    // Phase coloring: R = val, G = phase, B = r
    float phase = sin(cycle * 0.01 + val);
    imageStore(latticeImage, gid, vec4(val, phase, r, 1.0));
}
"""

# ---------------------
# Globals
# ---------------------
window = None      # reserved for the GLUT window id
shader = None      # compute program handle, set by init_gl()
texture = None     # RGBA32F output texture name, set by init_gl()
cycle = 0.0        # frame counter fed to the "cycle" uniform (+1.0 per frame)
omega_time = 0.0   # time value fed to the "omegaTime" uniform (+0.05 per frame)
img_size = 128     # width and height of the output image, in pixels
# One instance per pixel: the dispatch launches only img_size**2 invocations,
# so a larger count has no effect and a smaller one leaves pixels unwritten.
num_instances = img_size * img_size

# ---------------------
# OpenGL Init
# ---------------------
def init_gl():
    """Compile the compute program, allocate its output image, and set static uniforms.

    Populates the module globals ``shader`` (linked program) and ``texture``
    (img_size x img_size RGBA32F texture bound to image unit 0). Issues GL
    calls, so a GL context must be current.
    """
    global shader, texture

    stage = compileShader(COMPUTE_SRC, GL_COMPUTE_SHADER)
    shader = compileProgram(stage)

    # Single-level immutable RGBA32F storage, exposed to the shader as image unit 0
    texture = glGenTextures(1)
    glBindTexture(GL_TEXTURE_2D, texture)
    glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, img_size, img_size)
    glBindImageTexture(0, texture, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_RGBA32F)

    # Uniforms that stay constant across frames
    glUseProgram(shader)
    for uniform_name, uniform_value in (("phi", 1.6180339887), ("phiInv", 0.6180339887)):
        glUniform1f(glGetUniformLocation(shader, uniform_name), uniform_value)
    glUniform1i(glGetUniformLocation(shader, "numInstances"), num_instances)

# ---------------------
# Display
# ---------------------
def display():
    """GLUT display callback: run one compute pass, show the result, advance time.

    Reads the globals ``shader``, ``texture`` and ``img_size``; after the frame
    is presented, increments ``cycle`` by 1.0 and ``omega_time`` by 0.05.
    """
    global cycle, omega_time

    glUseProgram(shader)
    glUniform1f(glGetUniformLocation(shader, "cycle"), cycle)
    glUniform1f(glGetUniformLocation(shader, "omegaTime"), omega_time)

    # One 16x16 workgroup per tile; round up so a partial tile is still covered.
    wg = (img_size + 15) // 16
    glDispatchCompute(wg, wg, 1)
    # The shader writes with imageStore, but the quad below reads the texture
    # through ordinary texture sampling, which needs TEXTURE_FETCH. Keep the
    # image-access bit so later image loads/stores also observe the writes.
    glMemoryBarrier(GL_SHADER_IMAGE_ACCESS_BARRIER_BIT | GL_TEXTURE_FETCH_BARRIER_BIT)

    # Unbind the compute-only program so the quad is drawn with fixed-function
    # processing instead of with a compute program current.
    glUseProgram(0)

    # Simple screen draw of the texture
    glClear(GL_COLOR_BUFFER_BIT)
    glEnable(GL_TEXTURE_2D)
    glBindTexture(GL_TEXTURE_2D, texture)
    glBegin(GL_QUADS)
    glTexCoord2f(0, 0); glVertex2f(-1, -1)
    glTexCoord2f(1, 0); glVertex2f(1, -1)
    glTexCoord2f(1, 1); glVertex2f(1, 1)
    glTexCoord2f(0, 1); glVertex2f(-1, 1)
    glEnd()
    glutSwapBuffers()

    cycle += 1.0
    omega_time += 0.05

# ---------------------
# Idle
# ---------------------
def idle():
    """GLUT idle callback: request a redraw so frames are rendered continuously."""
    glutPostRedisplay()

# ---------------------
# Main
# ---------------------
def main():
    """Create the GLUT window, initialize GL state, and enter the GLUT event loop.

    Records the id returned by ``glutCreateWindow`` in the module-level
    ``window`` global, which was otherwise never assigned.
    """
    global window

    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutInitWindowSize(512, 512)
    window = glutCreateWindow(b"Prismatic Compute Shader HDGL")
    # init_gl() issues GL calls, so it must follow window (context) creation.
    init_gl()
    glutDisplayFunc(display)
    glutIdleFunc(idle)
    glutMainLoop()


if __name__ == "__main__":
    main()
